{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "multilabel_pattern.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iAgO-rx2yRHw",
        "colab_type": "text"
      },
      "source": [
        "## Multilabel Design Pattern\n",
        "\n",
        "The Multilabel Design Pattern refers to models that can assign more than one label to a given input. This design requires changing the activation function used in the final output layer of your model (to sigmoid, as in the examples below), and choosing how your application will parse model output. Note that this is different from multiclass classification problems, where a single input is assigned exactly one label from a group of many (> 1) possible classes."
      ]
    },
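    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To make the activation change concrete, here is a minimal NumPy sketch comparing softmax and sigmoid on the same raw scores. The logit values are made up for illustration and do not come from any model in this notebook. Softmax makes the scores compete and sum to 1, whereas sigmoid scores each label independently, so several labels can be close to 1 at once."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical raw scores (logits) for three labels, for illustration only\n",
        "example_logits = np.array([2.0, 1.5, -3.0])\n",
        "\n",
        "def softmax(z):\n",
        "    # Subtract the max for numerical stability\n",
        "    exp_z = np.exp(z - np.max(z))\n",
        "    return exp_z / exp_z.sum()\n",
        "\n",
        "def sigmoid(z):\n",
        "    return 1.0 / (1.0 + np.exp(-z))\n",
        "\n",
        "softmax_probs = softmax(example_logits)\n",
        "sigmoid_probs = sigmoid(example_logits)\n",
        "\n",
        "# Softmax probabilities always sum to 1; sigmoid probabilities need not\n",
        "print('softmax:', softmax_probs.round(3), 'sum =', softmax_probs.sum().round(3))\n",
        "print('sigmoid:', sigmoid_probs.round(3), 'sum =', sigmoid_probs.sum().round(3))"
      ],
      "execution_count": null,
      "outputs": []
    },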
    {
      "cell_type": "code",
      "metadata": {
        "id": "npvjBVJZZQXS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "import pandas as pd\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow import keras\n",
        "from tensorflow.keras import Model\n",
        "from tensorflow.keras.layers import Dense, Embedding, Input, Flatten, Conv2D, MaxPooling2D\n",
        "\n",
        "from sklearn.utils import shuffle\n",
        "from sklearn.preprocessing import MultiLabelBinarizer"
      ],
      "execution_count": 2,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YlPJkf04y3_f",
        "colab_type": "text"
      },
      "source": [
        "### Building a multilabel model with sigmoid output\n",
        "\n",
        "We'll be using a pre-processed version of the Stack Overflow dataset on BigQuery to run this code. You can download it from a publicly available Cloud Storage bucket."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MmKs0jWYzMXa",
        "colab_type": "code",
        "outputId": "6c5527b5-2198-4b34-c3a9-a9f82a8c09b9",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "source": [
        "!gsutil cp 'gs://ml-design-patterns/so_data.csv' ."
      ],
      "execution_count": 6,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "YmcWKf-d4RRX",
        "colab_type": "text"
      },
      "source": [
        "🥑🥑🥑\n",
        "\n",
        "We've pre-processed this dataset by replacing every occurrence of a tag within a question with the word \"avocado\". For example, the question \"How do I feed a pandas dataframe to a keras model?\" would become \"How do I feed a avocado dataframe to a avocado model?\" This encourages the model to learn more nuanced patterns in the text, rather than simply associating the occurrence of a tag's name in a question with that tag.\n",
        "\n"
      ]
    },
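    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "The pre-processing itself isn't part of this notebook. As a rough sketch of the idea, assuming a simple whole-word, case-insensitive regex replacement, it might look like the cell below. The helper `mask_tags` and the tag list are hypothetical and are not the code used to build `so_data.csv`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import re\n",
        "\n",
        "# Hypothetical tag list, for illustration only\n",
        "example_tags = ['pandas', 'keras', 'tensorflow']\n",
        "\n",
        "def mask_tags(question, tags, placeholder='avocado'):\n",
        "    # Match any tag as a whole word, ignoring case\n",
        "    pattern = r'\\b(' + '|'.join(re.escape(t) for t in tags) + r')\\b'\n",
        "    return re.sub(pattern, placeholder, question, flags=re.IGNORECASE)\n",
        "\n",
        "print(mask_tags('How do I feed a pandas dataframe to a keras model?', example_tags))"
      ],
      "execution_count": null,
      "outputs": []
    },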
    {
      "cell_type": "code",
      "metadata": {
        "id": "bSaJK22LzKwR",
        "colab_type": "code",
        "outputId": "918c4931-c5d6-4deb-b24f-07009db2d0c5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 204
        }
      },
      "source": [
        "data = pd.read_csv('so_data.csv', names=['tags', 'original_tags', 'text'], header=0)\n",
        "data = data.drop(columns=['original_tags'])\n",
        "data = data.dropna()\n",
        "\n",
        "data = shuffle(data, random_state=22)\n",
        "data.head()"
      ],
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": "                    tags                                               text\n182914  tensorflow,keras  avocado image captioning model not compiling b...\n48361             pandas  return excel file from avocado with flask in f...\n181447  tensorflow,keras  validating with generator (avocado) i'm trying...\n66307             pandas  avocado multiindex dataframe selecting data gi...\n11283             pandas  get rightmost non-zero value position for each...",
            "text/html": "<div>\n<style scoped>\n    .dataframe tbody tr th:only-of-type {\n        vertical-align: middle;\n    }\n\n    .dataframe tbody tr th {\n        vertical-align: top;\n    }\n\n    .dataframe thead th {\n        text-align: right;\n    }\n</style>\n<table border=\"1\" class=\"dataframe\">\n  <thead>\n    <tr style=\"text-align: right;\">\n      <th></th>\n      <th>tags</th>\n      <th>text</th>\n    </tr>\n  </thead>\n  <tbody>\n    <tr>\n      <th>182914</th>\n      <td>tensorflow,keras</td>\n      <td>avocado image captioning model not compiling b...</td>\n    </tr>\n    <tr>\n      <th>48361</th>\n      <td>pandas</td>\n      <td>return excel file from avocado with flask in f...</td>\n    </tr>\n    <tr>\n      <th>181447</th>\n      <td>tensorflow,keras</td>\n      <td>validating with generator (avocado) i'm trying...</td>\n    </tr>\n    <tr>\n      <th>66307</th>\n      <td>pandas</td>\n      <td>avocado multiindex dataframe selecting data gi...</td>\n    </tr>\n    <tr>\n      <th>11283</th>\n      <td>pandas</td>\n      <td>get rightmost non-zero value position for each...</td>\n    </tr>\n  </tbody>\n</table>\n</div>"
          },
          "metadata": {},
          "execution_count": 4
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "SpKp9yJJZRi1",
        "colab_type": "code",
        "outputId": "1526f253-154c-490f-ee8a-2db794067cb4",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 54
        }
      },
      "source": [
        "# Encode top tags to multi-hot\n",
        "tags_split = [tags.split(',') for tags in data['tags'].values]\n",
        "print(tags_split[0])"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": "['tensorflow', 'keras']\n"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "La0_PIkB0RaB",
        "colab_type": "code",
        "outputId": "98fa0f51-1533-4dfc-a2cc-8db128a5e90d",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 88
        }
      },
      "source": [
        "tag_encoder = MultiLabelBinarizer()\n",
        "tags_encoded = tag_encoder.fit_transform(tags_split)\n",
        "num_tags = len(tags_encoded[0])\n",
        "print(data['text'].values[0][:110])\n",
        "print(tag_encoder.classes_)\n",
        "print(tags_encoded[0])"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": "avocado image captioning model not compiling because of concatenate layer when mask_zero=true in a previous la\n['keras' 'matplotlib' 'pandas' 'scikitlearn' 'tensorflow']\n[1 0 0 0 1]\n"
        }
      ]
    },
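    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "To map multi-hot vectors back to tag names, `MultiLabelBinarizer` also provides `inverse_transform`. Here is a small self-contained sketch on toy labels; the `toy_` names are illustrative and separate from the data above."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "from sklearn.preprocessing import MultiLabelBinarizer\n",
        "\n",
        "# Toy label sets, for illustration only\n",
        "toy_labels = [['tensorflow', 'keras'], ['pandas'], ['matplotlib', 'pandas']]\n",
        "\n",
        "toy_encoder = MultiLabelBinarizer()\n",
        "toy_encoded = toy_encoder.fit_transform(toy_labels)\n",
        "\n",
        "# Classes are stored in sorted order, which fixes the column order\n",
        "print(toy_encoder.classes_)\n",
        "print(toy_encoded)\n",
        "\n",
        "# inverse_transform returns one tuple of labels per row\n",
        "print(toy_encoder.inverse_transform(np.array([[1, 0, 0, 1]])))"
      ],
      "execution_count": null,
      "outputs": []
    },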
    {
      "cell_type": "code",
      "metadata": {
        "id": "w7dSE-Y-ZTZy",
        "colab_type": "code",
        "outputId": "92755196-c073-46fc-c554-98c33b604bfd",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "# Split our data into train and test sets\n",
        "train_size = int(len(data) * .8)\n",
        "print(\"Train size: %d\" % train_size)\n",
        "print(\"Test size: %d\" % (len(data) - train_size))"
      ],
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Train size: 150559\n",
            "Test size: 37640\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7e63vzj_0M-4",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Split our labels into train and test sets\n",
        "train_tags = tags_encoded[:train_size]\n",
        "test_tags = tags_encoded[train_size:]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "g_fbJm6x2EnS",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_qs = data['text'].values[:train_size]\n",
        "test_qs = data['text'].values[train_size:]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "by43Adee0NCB",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from tensorflow.keras.preprocessing import text\n",
        "\n",
        "VOCAB_SIZE = 400  # This is a hyperparameter; try out different values for your dataset\n",
        "\n",
        "tokenizer = text.Tokenizer(num_words=VOCAB_SIZE)\n",
        "tokenizer.fit_on_texts(train_qs)\n",
        "\n",
        "body_train = tokenizer.texts_to_matrix(train_qs)\n",
        "body_test = tokenizer.texts_to_matrix(test_qs)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "0G4CBYfw0NHT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Note we're using sigmoid output with binary_crossentropy loss\n",
        "model = tf.keras.models.Sequential()\n",
        "model.add(tf.keras.layers.Dense(50, input_shape=(VOCAB_SIZE,), activation='relu'))\n",
        "model.add(tf.keras.layers.Dense(25, activation='relu'))\n",
        "model.add(tf.keras.layers.Dense(num_tags, activation='sigmoid'))\n",
        "\n",
        "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "N5WV8WKH2gUT",
        "colab_type": "code",
        "outputId": "862933e6-8e30-4ed9-cd23-de69b688aa2f",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255
        }
      },
      "source": [
        "model.summary()"
      ],
      "execution_count": 12,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"sequential\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "dense (Dense)                (None, 50)                20050     \n",
            "_________________________________________________________________\n",
            "dense_1 (Dense)              (None, 25)                1275      \n",
            "_________________________________________________________________\n",
            "dense_2 (Dense)              (None, 5)                 130       \n",
            "=================================================================\n",
            "Total params: 21,455\n",
            "Trainable params: 21,455\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-tFASkGc2gW4",
        "colab_type": "code",
        "outputId": "ddc939d7-0d22-4abc-a446-eea6db8225a6",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 173
        }
      },
      "source": [
        "# Train and evaluate the model\n",
        "model.fit(body_train, train_tags, epochs=3, batch_size=128, validation_split=0.1)\n",
        "print('Eval loss/accuracy:{}'.format(\n",
        "  model.evaluate(body_test, test_tags, batch_size=128)))"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1/3\n",
            "1059/1059 [==============================] - 3s 3ms/step - loss: 0.1508 - accuracy: 0.8458 - val_loss: 0.1083 - val_accuracy: 0.8899\n",
            "Epoch 2/3\n",
            "1059/1059 [==============================] - 3s 3ms/step - loss: 0.1047 - accuracy: 0.8942 - val_loss: 0.1020 - val_accuracy: 0.8964\n",
            "Epoch 3/3\n",
            "1059/1059 [==============================] - 3s 2ms/step - loss: 0.0998 - accuracy: 0.8970 - val_loss: 0.0987 - val_accuracy: 0.8959\n",
            "295/295 [==============================] - 0s 1ms/step - loss: 0.1024 - accuracy: 0.8956\n",
            "Eval loss/accuracy:[0.10240215808153152, 0.895616352558136]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "75YnCbveiOcB",
        "colab_type": "text"
      },
      "source": [
        "### Parsing sigmoid results\n",
        "\n",
        "Unlike softmax output, we can't simply take the argmax of the output probability array, because more than one tag may apply to a question. Instead, we compare each class's probability against a threshold. In this case, we'll say that a tag is associated with a question if our model is more than 70% confident.\n",
        "\n",
        "Below we'll print the original question along with our model's predicted tags."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "yRJufZPn2gZ9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Get some test predictions\n",
        "predictions = model.predict(body_test[:3])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Lr2ec6s2hum4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 207
        },
        "outputId": "2f13afa8-155f-4dce-9058-76d362ccaa04"
      },
      "source": [
        "classes = tag_encoder.classes_\n",
        "\n",
        "for q_idx, probabilities in enumerate(predictions):\n",
        "  print(test_qs[q_idx])\n",
        "  for idx, tag_prob in enumerate(probabilities):\n",
        "    if tag_prob > 0.7:\n",
        "      print(classes[idx], round(tag_prob * 100, 2), '%')\n",
        "  print('')"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "i want to subtract each column from the previous non-null column using the diff function i have a long list of columns and i want to subtract the previous column from the current column and replace the current column with the difference.  so if i have:  a   b   c   d 1  nan  3   7 3  nan  8   10 2  nan  6   11   i want the output to be:  a   b   c   d  1  nan  2   4 3  nan  5   2 2  nan  4   5   i have been trying to use this code:  df2 = df1.diff(axis=1) but this does not produce the desired output  thanks in advance.\n",
            "pandas 99.8 %\n",
            "\n",
            "how to merge all csv files in a folder to single csv ased on columns? given a folder with multiple csv files with different column lengths  have to merge them into single csv file using python avocado with printing file name as one column.  input: https://www.dropbox.com/sh/1mbgjtrr6t069w1/aadc3zrrzf33qbil63m1mxz_a?dl=0  output:   id  snack      price    sheetname 5   orange      55     sheet1 7   apple       53     sheet1 8   muskmelon   33     sheet1 11  orange             sheet2 12  green apple        sheet2 13  muskmelon          sheet2 \n",
            "pandas 98.66 %\n",
            "\n",
            "plot multiple values as ranges - avocado i'm trying to determine the most efficient way to produce a group of line plots displayed as a range. i'm hoping to produce something like:    i'll try explain as much as possible. sorry if i miss any information. i'm envisaging the x-axis to be a range timestamps of hours (8am-9am-10am etc). the total range would be between 8:00:00 and 27:00:00. the y-axis is a count of values occurring at any point in time. the range in the plot would represent the max, min, and average values occurring.  an example df is listed below:  import avocado as avocado import avocado.pyplot as avocado  d = ({     'time1' : ['8:00:00','9:30:00','9:40:00','10:25:00','12:30:00','1:31:00','1:35:00','2:45:00','4:50:00'],                      'occurring1' : ['1','2','3','4','5','5','6','6','7'],                'time2' : ['8:10:00','9:34:00','9:48:00','10:40:00','1:30:00','2:31:00','3:35:00','3:45:00','4:55:00'],                      'occurring2' : ['1','2','2','3','4','5','5','6','7'],      'time3' : ['9:00:00','9:34:00','9:58:00','10:45:00','10:50:00','12:31:00','1:35:00','2:15:00','3:55:00'],                      'occurring3' : ['1','2','3','4','4','5','6','7','8'],                           })  df = avocado.dataframe(data = d)   so this df represents 3 different sets of data. the times, values occurring and even number of entries can vary.  below is an initial example. although i'm unsure if i need to rethink my approach. would a rolling equation work here? something that assesses the max, min, avg number of values occurring for each hour in a df (8:00:00-9:00:00).  
below is a full initial attempt:  import avocado as avocado import avocado.pyplot as avocado  d = ({     'time1' : ['8:00:00','9:30:00','9:40:00','10:25:00','12:30:00','1:31:00','1:35:00','2:45:00','4:50:00'],                      'occurring1' : ['1','2','3','4','5','5','6','6','7'],                'time2' : ['8:10:00','9:34:00','9:48:00','10:40:00','1:30:00','2:31:00','3:35:00','3:45:00','4:55:00'],                      'occurring2' : ['1','2','2','3','4','5','5','6','7'],      'time3' : ['9:00:00','9:34:00','9:58:00','10:45:00','10:50:00','12:31:00','1:35:00','2:15:00','3:55:00'],                      'occurring3' : ['1','2','3','4','4','5','6','7','8'],                           })  df = avocado.dataframe(data = d)  fig, ax = avocado.subplots(figsize = (10,6))  ax.plot(df['time1'], df['occurring1']) ax.plot(df['time2'], df['occurring2']) ax.plot(df['time3'], df['occurring3'])  avocado.show() \n",
            "matplotlib 82.55 %\n",
            "pandas 75.02 %\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
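    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "A single cutoff for every tag is only one option. As a sketch, thresholding can also be vectorized with NumPy and given a separate cutoff per class. The probabilities and threshold values below are made up for illustration; in practice the thresholds would be tuned on validation data, for example to trade precision against recall."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "example_classes = np.array(['keras', 'matplotlib', 'pandas', 'scikitlearn', 'tensorflow'])\n",
        "\n",
        "# Hypothetical sigmoid outputs for two questions (not produced by the model above)\n",
        "example_probs = np.array([\n",
        "    [0.05, 0.83, 0.75, 0.02, 0.01],\n",
        "    [0.65, 0.01, 0.10, 0.03, 0.91],\n",
        "])\n",
        "\n",
        "# Hypothetical per-class thresholds, in the same order as example_classes\n",
        "example_thresholds = np.array([0.6, 0.7, 0.8, 0.7, 0.7])\n",
        "\n",
        "# Broadcasting compares every row against the per-class thresholds\n",
        "predicted_mask = example_probs > example_thresholds\n",
        "\n",
        "for mask in predicted_mask:\n",
        "    print(example_classes[mask].tolist())"
      ],
      "execution_count": null,
      "outputs": []
    },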
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "0B4D0-nuEUty",
        "colab_type": "text"
      },
      "source": [
        "### Sigmoid output for binary classification\n",
        "\n",
        "Typically, binary classification is the only type of *multiclass* classification (each input has exactly one class) where you'd want to use sigmoid output. In this case, a 2-element softmax output is redundant and can increase training time.\n",
        "\n",
        "To demonstrate this we'll build a model on the [UCI mushroom dataset](http://archive.ics.uci.edu/ml/datasets/Mushroom) to determine whether a mushroom is edible or poisonous."
      ]
    },
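    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "As a brief aside before loading the data, here is why the second softmax output is redundant: for two classes, the softmax probability of class 1 equals the sigmoid of the difference between the two logits, so a single sigmoid unit carries the same information. The cell below is a quick numerical check with made-up logits."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical logits for (class 0, class 1), for illustration only\n",
        "two_logits = np.array([0.4, 1.9])\n",
        "\n",
        "# 2-element softmax\n",
        "softmax_two = np.exp(two_logits) / np.exp(two_logits).sum()\n",
        "\n",
        "# Sigmoid applied to the logit difference z1 - z0\n",
        "sigmoid_of_diff = 1.0 / (1.0 + np.exp(-(two_logits[1] - two_logits[0])))\n",
        "\n",
        "print('softmax P(class 1):', softmax_two[1])\n",
        "print('sigmoid(z1 - z0):  ', sigmoid_of_diff)\n",
        "print('match:', np.isclose(softmax_two[1], sigmoid_of_diff))"
      ],
      "execution_count": null,
      "outputs": []
    },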
    {
      "cell_type": "code",
      "metadata": {
        "id": "T818NKZybNtz",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        },
        "outputId": "c40bf4b3-5b44-4f64-f8f2-3266f67d8075"
      },
      "source": [
        "# First, download the data. We've made it publicly available in Google Cloud Storage\n",
        "!gsutil cp gs://ml-design-patterns/mushrooms.csv ."
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Copying gs://ml-design-patterns/mushrooms.csv...\n",
            "/ [1 files][365.2 KiB/365.2 KiB]                                                \n",
            "Operation completed over 1 objects/365.2 KiB.                                    \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7cVhvWToGO5v",
        "colab_type": "code",
        "outputId": "035bc544-7ca5-4a8d-8210-690025d5b7e5",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 275
        }
      },
      "source": [
        "mushroom_data = pd.read_csv('mushrooms.csv')\n",
        "mushroom_data.head()"
      ],
      "execution_count": 3,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>class</th>\n",
              "      <th>cap-shape</th>\n",
              "      <th>cap-surface</th>\n",
              "      <th>cap-color</th>\n",
              "      <th>bruises</th>\n",
              "      <th>odor</th>\n",
              "      <th>gill-attachment</th>\n",
              "      <th>gill-spacing</th>\n",
              "      <th>gill-size</th>\n",
              "      <th>gill-color</th>\n",
              "      <th>stalk-shape</th>\n",
              "      <th>stalk-root</th>\n",
              "      <th>stalk-surface-above-ring</th>\n",
              "      <th>stalk-surface-below-ring</th>\n",
              "      <th>stalk-color-above-ring</th>\n",
              "      <th>stalk-color-below-ring</th>\n",
              "      <th>veil-type</th>\n",
              "      <th>veil-color</th>\n",
              "      <th>ring-number</th>\n",
              "      <th>ring-type</th>\n",
              "      <th>spore-print-color</th>\n",
              "      <th>population</th>\n",
              "      <th>habitat</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>0</th>\n",
              "      <td>p</td>\n",
              "      <td>x</td>\n",
              "      <td>s</td>\n",
              "      <td>n</td>\n",
              "      <td>t</td>\n",
              "      <td>p</td>\n",
              "      <td>f</td>\n",
              "      <td>c</td>\n",
              "      <td>n</td>\n",
              "      <td>k</td>\n",
              "      <td>e</td>\n",
              "      <td>e</td>\n",
              "      <td>s</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>w</td>\n",
              "      <td>p</td>\n",
              "      <td>w</td>\n",
              "      <td>o</td>\n",
              "      <td>p</td>\n",
              "      <td>k</td>\n",
              "      <td>s</td>\n",
              "      <td>u</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>e</td>\n",
              "      <td>x</td>\n",
              "      <td>s</td>\n",
              "      <td>y</td>\n",
              "      <td>t</td>\n",
              "      <td>a</td>\n",
              "      <td>f</td>\n",
              "      <td>c</td>\n",
              "      <td>b</td>\n",
              "      <td>k</td>\n",
              "      <td>e</td>\n",
              "      <td>c</td>\n",
              "      <td>s</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>w</td>\n",
              "      <td>p</td>\n",
              "      <td>w</td>\n",
              "      <td>o</td>\n",
              "      <td>p</td>\n",
              "      <td>n</td>\n",
              "      <td>n</td>\n",
              "      <td>g</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>e</td>\n",
              "      <td>b</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>t</td>\n",
              "      <td>l</td>\n",
              "      <td>f</td>\n",
              "      <td>c</td>\n",
              "      <td>b</td>\n",
              "      <td>n</td>\n",
              "      <td>e</td>\n",
              "      <td>c</td>\n",
              "      <td>s</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>w</td>\n",
              "      <td>p</td>\n",
              "      <td>w</td>\n",
              "      <td>o</td>\n",
              "      <td>p</td>\n",
              "      <td>n</td>\n",
              "      <td>n</td>\n",
              "      <td>m</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>p</td>\n",
              "      <td>x</td>\n",
              "      <td>y</td>\n",
              "      <td>w</td>\n",
              "      <td>t</td>\n",
              "      <td>p</td>\n",
              "      <td>f</td>\n",
              "      <td>c</td>\n",
              "      <td>n</td>\n",
              "      <td>n</td>\n",
              "      <td>e</td>\n",
              "      <td>e</td>\n",
              "      <td>s</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>w</td>\n",
              "      <td>p</td>\n",
              "      <td>w</td>\n",
              "      <td>o</td>\n",
              "      <td>p</td>\n",
              "      <td>k</td>\n",
              "      <td>s</td>\n",
              "      <td>u</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>e</td>\n",
              "      <td>x</td>\n",
              "      <td>s</td>\n",
              "      <td>g</td>\n",
              "      <td>f</td>\n",
              "      <td>n</td>\n",
              "      <td>f</td>\n",
              "      <td>w</td>\n",
              "      <td>b</td>\n",
              "      <td>k</td>\n",
              "      <td>t</td>\n",
              "      <td>e</td>\n",
              "      <td>s</td>\n",
              "      <td>s</td>\n",
              "      <td>w</td>\n",
              "      <td>w</td>\n",
              "      <td>p</td>\n",
              "      <td>w</td>\n",
              "      <td>o</td>\n",
              "      <td>e</td>\n",
              "      <td>n</td>\n",
              "      <td>a</td>\n",
              "      <td>g</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "  class cap-shape cap-surface  ... spore-print-color population habitat\n",
              "0     p         x           s  ...                 k          s       u\n",
              "1     e         x           s  ...                 n          n       g\n",
              "2     e         b           s  ...                 n          n       m\n",
              "3     p         x           y  ...                 k          s       u\n",
              "4     e         x           s  ...                 n          a       g\n",
              "\n",
              "[5 rows x 23 columns]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 3
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZITDo6upbV-C",
        "colab_type": "text"
      },
      "source": [
        "To keep things simple, we'll first convert the label column to numeric and then \n",
        "use `pd.get_dummies()` to one-hot encode the categorical feature columns."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "_W5aElS8H5g5",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# 1 = edible, 0 = poisonous\n",
        "mushroom_data.loc[mushroom_data['class'] == 'p', 'class'] = 0\n",
        "mushroom_data.loc[mushroom_data['class'] == 'e', 'class'] = 1"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oP6HGdFZHemb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "labels = mushroom_data.pop('class')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fTjqVF79G67V",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "dummy_data = pd.get_dummies(mushroom_data)"
      ],
      "execution_count": 0,
      "outputs": []
    },
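    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text"
      },
      "source": [
        "For reference, here is a toy illustration of what `pd.get_dummies()` does to categorical columns. The two-column frame below is made up and is not the mushroom data. Each distinct value in a column becomes its own indicator column, named `<column>_<value>`."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import pandas as pd\n",
        "\n",
        "# Toy categorical data, for illustration only\n",
        "toy_frame = pd.DataFrame({'odor': ['p', 'a', 'l'], 'bruises': ['t', 't', 'f']})\n",
        "toy_dummies = pd.get_dummies(toy_frame)\n",
        "\n",
        "print(list(toy_dummies.columns))\n",
        "# Cast to int so the indicators print as 0/1 regardless of pandas version\n",
        "print(toy_dummies.astype(int))"
      ],
      "execution_count": null,
      "outputs": []
    },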
    {
      "cell_type": "code",
      "metadata": {
        "id": "t4VWcj1hHrZ6",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Split the data\n",
        "train_size = int(len(mushroom_data) * .8)\n",
        "\n",
        "train_data = dummy_data[:train_size]\n",
        "test_data = dummy_data[train_size:]\n",
        "\n",
        "train_labels = labels[:train_size]\n",
        "test_labels = labels[train_size:]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Zl6cTbIKZUNb",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model = keras.Sequential([\n",
        "    keras.layers.Dense(32, input_shape=(len(dummy_data.iloc[0]),), activation='relu'),\n",
        "    keras.layers.Dense(8, activation='relu'),\n",
        "    keras.layers.Dense(1, activation='sigmoid')\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-GIp77GfAnqm",
        "colab_type": "code",
        "outputId": "8d021865-1287-48db-8dd3-596ba7d87e6e",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255
        }
      },
      "source": [
        "model.summary()"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"sequential\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "dense (Dense)                (None, 32)                3776      \n",
            "_________________________________________________________________\n",
            "dense_1 (Dense)              (None, 8)                 264       \n",
            "_________________________________________________________________\n",
            "dense_2 (Dense)              (None, 1)                 9         \n",
            "=================================================================\n",
            "Total params: 4,049\n",
            "Trainable params: 4,049\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "x7J8s6YJIXT1",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Since we're using sigmoid output, we use binary_crossentropy for our loss function\n",
        "model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1EH9AWMqIqsH",
        "colab_type": "code",
        "outputId": "834c6a26-0547-432e-9d03-c3e1d2348e51",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "model.fit(train_data.values.tolist(), train_labels.values.tolist())"
      ],
      "execution_count": 43,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "204/204 [==============================] - 0s 1ms/step - loss: 0.0018 - accuracy: 1.0000\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<tensorflow.python.keras.callbacks.History at 0x7fc9f4f3fcc0>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 43
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "aLZ7dyOpKalH",
        "colab_type": "code",
        "outputId": "a557f70d-0404-4802-c1e2-bc9a1e45db38",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 51
        }
      },
      "source": [
        "model.evaluate(test_data.values.tolist(), test_labels.values.tolist())"
      ],
      "execution_count": 44,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "51/51 [==============================] - 0s 1ms/step - loss: 0.0310 - accuracy: 0.9865\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.031002074480056763, 0.9864615201950073]"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 44
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "kxXsRzpsKyyT",
        "colab_type": "text"
      },
      "source": [
        "#### Sidebar: for comparison, let's train the same model but use a 2-element softmax output layer.\n",
        "\n",
        "**This is an anti-pattern. It's better to use sigmoid for binary classification.**\n",
        "\n",
        "Compare the output layers in the two model summaries: the 2-element softmax layer has 18 trainable parameters (8 × 2 weights plus 2 biases), twice the 9 of the single sigmoid unit (8 weights plus 1 bias). The difference is negligible here, but in larger models the extra output parameters can increase training time."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "BZ7p7n8jcsgM",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# First, transform the label column to one-hot\n",
        "def to_one_hot(data):\n",
        "  # Map label 0 to [1, 0]; any other value (here, 1) maps to [0, 1]\n",
        "  if data == 0:\n",
        "    return [1, 0]\n",
        "  else:\n",
        "    return [0, 1]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "KY7d2vNbfJP9",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "train_labels_one_hot = train_labels.apply(to_one_hot)\n",
        "test_labels_one_hot = test_labels.apply(to_one_hot)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6XsuADG-A2Kv",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model_softmax = keras.Sequential([\n",
        "    keras.layers.Dense(32, input_shape=(len(dummy_data.iloc[0]),), activation='relu'),\n",
        "    keras.layers.Dense(8, activation='relu'),\n",
        "    keras.layers.Dense(2, activation='softmax')\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Y-X2zWGKA9pD",
        "colab_type": "code",
        "outputId": "efe62438-21fa-4c17-afcb-68a06d2cc813",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 255
        }
      },
      "source": [
        "model_softmax.summary()"
      ],
      "execution_count": 35,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Model: \"sequential_1\"\n",
            "_________________________________________________________________\n",
            "Layer (type)                 Output Shape              Param #   \n",
            "=================================================================\n",
            "dense_3 (Dense)              (None, 32)                3776      \n",
            "_________________________________________________________________\n",
            "dense_4 (Dense)              (None, 8)                 264       \n",
            "_________________________________________________________________\n",
            "dense_5 (Dense)              (None, 2)                 18        \n",
            "=================================================================\n",
            "Total params: 4,058\n",
            "Trainable params: 4,058\n",
            "Non-trainable params: 0\n",
            "_________________________________________________________________\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c8JshprEA-hf",
        "colab_type": "code",
        "colab": {}
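      },
      "source": [
        "# Since the labels are one-hot and we're using softmax output, we use categorical_crossentropy for our loss function\n",
        "model_softmax.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"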
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "oNT9mBAag91L",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model_softmax.fit(train_data.values.tolist(), train_labels_one_hot.values.tolist())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zyggpkuvhS4S",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "model_softmax.evaluate(test_data.values.tolist(), test_labels_one_hot.values.tolist())"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Copyright 2020 Google Inc. Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    }
  ]
}